This analysis explores EEG spectrogram embeddings to predict expert consensus classifications. We’ll implement:
# Function to install missing packages
install_if_missing <- function(packages) {
new_packages <- packages[!(packages %in% installed.packages()[,"Package"])]
if(length(new_packages)) {
repos <- "https://cloud.r-project.org"
install.packages(new_packages, dependencies = TRUE, repos = repos)
}
}
# List of required packages
required_packages <- c("tidyverse", "caret", "ranger", "lightgbm", "e1071",
"umap", "ggplot2", "corrplot", "doParallel", "pROC",
"reshape2", "viridis", "gridExtra", "nnet") # Add nnet for multinom
# Install missing packages
install_if_missing(required_packages)
# Load necessary libraries
library(tidyverse)
library(caret)
library(ranger)
library(lightgbm)
library(e1071)
library(umap)
library(ggplot2)
library(corrplot)
library(doParallel)
library(pROC)
library(reshape2)
library(viridis)
library(gridExtra)
# Set seed for reproducibility
set.seed(42)
# Set up parallel processing
registerDoParallel(cores = parallel::detectCores() - 1)# Using file.path for OS-independent path handling
file_path <- file.path("..", "out", "test", "vit_embeddings_test.csv")
data <- read.csv(file_path) # Assign to 'data' variable
# Print successful data load
cat("Successfully loaded data with", nrow(data), "rows and", ncol(data), "columns\n")## Successfully loaded data with 2136 rows and 774 columns
## 'data.frame': 2136 obs. of 774 variables:
## $ spectrogram_id : int 686402130 959372535 250501602 1872858502 1540613004 1606814511 1694525835 1845795531 1993693261 38177271 ...
## $ spectrogram_sub_id: int 30 5 8 1 7 0 5 2 36 12 ...
## $ eeg_id : num 3.53e+08 1.76e+09 3.69e+08 1.84e+09 2.79e+09 ...
## $ eeg_sub_id : int 30 5 8 1 7 0 5 2 36 12 ...
## $ offset_seconds : num 108 18 44 4 34 0 12 22 178 26 ...
## $ expert_consensus : chr "GPD" "GRDA" "Seizure" "LPD" ...
## $ embedding_0 : num -0.186 0.534 0.611 0.503 0.359 ...
## $ embedding_1 : num -0.92 -0.552 0.656 0.861 1.12 ...
## $ embedding_2 : num 0.0599 -1.1648 -1.043 -0.7725 0.172 ...
## $ embedding_3 : num 0.368583 -0.200468 -0.045985 -0.198015 -0.000211 ...
## [list output truncated]
# Check for missing values
missing_values <- colSums(is.na(data))
if(sum(missing_values) > 0) {
cat("Missing values found in", sum(missing_values > 0), "columns\n")
} else {
cat("No missing values found in the dataset\n")
}## No missing values found in the dataset
# Convert target to factor
data$expert_consensus <- as.factor(data$expert_consensus)
# Target distribution
target_dist <- table(data$expert_consensus)
print("Class distribution:")## [1] "Class distribution:"
##
## GPD GRDA LPD LRDA Other Seizure
## 321 373 297 324 391 430
# Visualize class distribution
barplot_data <- data.frame(
Class = names(target_dist),
Count = as.numeric(target_dist)
)
ggplot(barplot_data, aes(x = Class, y = Count, fill = Class)) +
geom_bar(stat = "identity") +
geom_text(aes(label = Count), vjust = -0.5) +
scale_fill_viridis_d() +
labs(title = "Distribution of Expert Consensus Classes",
x = "Class", y = "Count") +
theme_minimal() +
theme(legend.position = "none")# Extract features and target
embedding_cols <- grep("embedding_", names(data), value = TRUE)
X <- data[, embedding_cols]
y <- data$expert_consensus
# Split data into training and testing sets (70/30)
set.seed(42)
train_index <- createDataPartition(y, p = 0.7, list = FALSE)
X_train <- X[train_index, ]
X_test <- X[-train_index, ]
y_train <- y[train_index]
y_test <- y[-train_index]
# Check embedding dimension
cat("Embedding dimension:", ncol(X), "\n")## Embedding dimension: 768
# Basic feature statistics
embedding_stats <- data.frame(
mean = colMeans(X_train),
std = apply(X_train, 2, sd),
min = apply(X_train, 2, min),
max = apply(X_train, 2, max),
median = apply(X_train, 2, median)
)
# Plot distribution statistics
par(mfrow = c(2, 2))
hist(embedding_stats$mean, main = "Embedding Mean Distribution", col = "skyblue")
hist(embedding_stats$std, main = "Embedding Std Distribution", col = "lightgreen")
hist(embedding_stats$min, main = "Embedding Min Values", col = "salmon")
hist(embedding_stats$max, main = "Embedding Max Values", col = "plum")par(mfrow = c(1, 1))
# Sample correlation matrix (for a subset of embeddings)
set.seed(123)
sample_size <- min(30, length(embedding_cols))
sample_embeddings <- sample(embedding_cols, sample_size)
corr_matrix <- cor(X_train[, sample_embeddings])
# Visualize correlation
corrplot(corr_matrix, method = "circle", type = "upper",
tl.cex = 0.6, tl.col = "black", tl.srt = 45,
title = "Sample Embedding Correlation")# Apply UMAP for visualization
set.seed(42)
umap_result <- umap(X_train, n_components = 2, n_neighbors = 15, min_dist = 0.1)
# Create visualization data
umap_df <- data.frame(
UMAP1 = umap_result$layout[,1],
UMAP2 = umap_result$layout[,2],
Class = y_train
)
# Plot UMAP visualization
ggplot(umap_df, aes(x = UMAP1, y = UMAP2, color = Class)) +
geom_point(alpha = 0.7) +
scale_color_viridis_d() +
labs(title = "UMAP Visualization of EEG Embedding Space",
subtitle = "Colored by Expert Consensus Class") +
theme_minimal()# Apply PCA for dimensionality reduction (for LR)
pca_result <- prcomp(X_train, center = TRUE, scale. = TRUE)
# Calculate variance explained
variance_explained <- summary(pca_result)$importance[2,]
cumulative_variance <- summary(pca_result)$importance[3,]
# Find number of components for 95% variance
n_components_95 <- which(cumulative_variance >= 0.95)[1]
cat("Number of PCA components for 95% variance:", n_components_95, "\n")## Number of PCA components for 95% variance: 118
# Plot variance explained
plot(cumulative_variance, type = "b", xlab = "Principal Component",
ylab = "Cumulative Proportion of Variance",
main = "PCA Cumulative Variance")
abline(h = 0.95, col = "red", lty = 2)# Prepare data for LightGBM
num_classes <- length(levels(y_train))
y_train_numeric <- as.integer(y_train) - 1 # Convert to 0-based index
dtrain <- lgb.Dataset(
data = as.matrix(X_train),
label = y_train_numeric
)
# Define parameter grid
lgb_params <- list(
objective = "multiclass",
metric = "multi_logloss",
num_class = num_classes,
verbose = -1,
feature_pre_filter = FALSE # Add this line to fix the error
)
# Function to evaluate a parameter set with CV
evaluate_params <- function(params, nrounds = 100, nfold = 5) {
full_params <- c(lgb_params, params)
cv_result <- lgb.cv(
params = full_params,
data = dtrain,
nrounds = nrounds,
nfold = nfold,
early_stopping_rounds = 20,
verbose = 0
)
# Return best score and iteration
best_iter <- which.min(cv_result$best_iter)
return(list(
score = min(cv_result$best_score),
iteration = best_iter
))
}
# Grid of parameters to try
param_grid <- list(
list(num_leaves = 31, learning_rate = 0.1, max_depth = 6,
feature_fraction = 0.8, bagging_fraction = 0.8, min_data_in_leaf = 20),
list(num_leaves = 63, learning_rate = 0.05, max_depth = 8,
feature_fraction = 0.7, bagging_fraction = 0.7, min_data_in_leaf = 30),
list(num_leaves = 127, learning_rate = 0.01, max_depth = 10,
feature_fraction = 0.9, bagging_fraction = 0.9, min_data_in_leaf = 10)
)
# Find best parameters
cat("Tuning LightGBM hyperparameters...\n")## Tuning LightGBM hyperparameters...
lgb_results <- list()
best_score <- Inf
best_params <- NULL
best_iter <- 100
for (i in 1:length(param_grid)) {
cat("Evaluating parameter set", i, "of", length(param_grid), "\n")
result <- evaluate_params(param_grid[[i]])
lgb_results[[i]] <- list(
params = param_grid[[i]],
score = result$score,
iteration = result$iteration
)
if (result$score < best_score) {
best_score <- result$score
best_params <- param_grid[[i]]
best_iter <- result$iteration
}
}## Evaluating parameter set 1 of 3
## Evaluating parameter set 2 of 3
## Evaluating parameter set 3 of 3
##
## Best LightGBM parameters:
## $num_leaves
## [1] 63
##
## $learning_rate
## [1] 0.05
##
## $max_depth
## [1] 8
##
## $feature_fraction
## [1] 0.7
##
## $bagging_fraction
## [1] 0.7
##
## $min_data_in_leaf
## [1] 30
## Best CV score: 1.364927
# Train final LightGBM model with best parameters
final_params <- c(lgb_params, best_params)
set.seed(42)
lgb_model <- lgb.train(
params = final_params,
data = dtrain,
nrounds = 100, # Typically you'd use best_iter, setting to 100 for simplicity
verbose = 0
)
# Feature importance
lgb_importance <- lgb.importance(lgb_model, percentage = TRUE)
lgb_imp_plot <- lgb.plot.importance(lgb_importance, top_n = 20, measure = "Gain")# Make predictions
lgb_probs <- predict(lgb_model, as.matrix(X_test))
lgb_prob_matrix <- matrix(lgb_probs, ncol = num_classes, byrow = TRUE)
lgb_preds <- max.col(lgb_prob_matrix) - 1 # Convert to 0-based class indices
lgb_preds <- factor(lgb_preds, levels = 0:(num_classes-1)) # Convert to factor
# Convert y_test to same factor levels for comparison
y_test_0idx <- as.integer(y_test) - 1 # Convert to 0-based index
y_test_factor <- factor(y_test_0idx, levels = 0:(num_classes-1))
# Calculate confusion matrix
lgb_cm <- confusionMatrix(lgb_preds, y_test_factor)
print("LightGBM Performance:")## [1] "LightGBM Performance:"
## Accuracy Kappa AccuracyLower AccuracyUpper AccuracyNull
## 0.167449139 0.001565432 0.139307911 0.198702152 0.201877934
## AccuracyPValue McnemarPValue
## 0.988080707 0.131046639
# Convert response to factor
train_data <- data.frame(X_train, class = y_train)
test_data <- data.frame(X_test, class = y_test)
# Define control parameters for tuning
control <- trainControl(
method = "cv",
number = 5,
search = "grid",
classProbs = TRUE,
verboseIter = FALSE
)
# Define parameter grid for random forest
rf_grid <- expand.grid(
mtry = c(floor(sqrt(ncol(X_train))), floor(ncol(X_train)/3), floor(ncol(X_train)/5)),
splitrule = c("gini", "extratrees"),
min.node.size = c(1, 5, 10)
)
# Train Random Forest
cat("Training Random Forest model...\n")## Training Random Forest model...
set.seed(42)
rf_model <- train(
x = X_train,
y = y_train, # Make sure y_train is a factor
method = "ranger",
trControl = control,
tuneGrid = rf_grid,
importance = "impurity",
num.trees = 500
)
# Display tuning results
print(rf_model)## Random Forest
##
## 1497 samples
## 768 predictor
## 6 classes: 'GPD', 'GRDA', 'LPD', 'LRDA', 'Other', 'Seizure'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 1199, 1197, 1197, 1197, 1198
## Resampling results across tuning parameters:
##
## mtry splitrule min.node.size Accuracy Kappa
## 27 gini 1 0.4822965 0.3723073
## 27 gini 5 0.4802898 0.3696975
## 27 gini 10 0.4822831 0.3719814
## 27 extratrees 1 0.4782898 0.3667707
## 27 extratrees 5 0.4802831 0.3690726
## 27 extratrees 10 0.4789609 0.3674337
## 153 gini 1 0.4829855 0.3731245
## 153 gini 5 0.4816477 0.3715242
## 153 gini 10 0.4762830 0.3649651
## 153 extratrees 1 0.4749564 0.3634977
## 153 extratrees 5 0.4742786 0.3622398
## 153 extratrees 10 0.4742942 0.3621538
## 256 gini 1 0.4876499 0.3788775
## 256 gini 5 0.4689743 0.3563254
## 256 gini 10 0.4816343 0.3713702
## 256 extratrees 1 0.4749632 0.3631573
## 256 extratrees 5 0.4769475 0.3656100
## 256 extratrees 10 0.4776164 0.3661686
##
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were mtry = 256, splitrule = gini
## and min.node.size = 1.
# Variable importance
rf_importance <- varImp(rf_model)
plot(rf_importance, top = 20, main = "Random Forest: Top 20 Variables")# Make predictions
rf_preds <- predict(rf_model, test_data)
rf_probs <- predict(rf_model, test_data, type = "prob")
# Calculate confusion matrix
rf_cm <- confusionMatrix(rf_preds, y_test)
print("Random Forest Performance:")## [1] "Random Forest Performance:"
## Accuracy Kappa AccuracyLower AccuracyUpper AccuracyNull
## 5.023474e-01 3.959477e-01 4.628639e-01 5.418092e-01 2.018779e-01
## AccuracyPValue McnemarPValue
## 5.891456e-64 8.036481e-11
# Prepare data for Logistic Regression - using PCA reduced data for computational efficiency
train_data_pca <- data.frame(X_train_pca, class = y_train)
test_data_pca <- data.frame(X_test_pca, class = y_test)
# Define parameter grid for Logistic Regression
lr_grid <- expand.grid(
decay = c(0, 0.001, 0.01, 0.1) # Regularization parameter
)
# Train Logistic Regression
cat("Training Logistic Regression model...\n")## Training Logistic Regression model...
set.seed(42)
lr_model <- train(
class ~ .,
data = train_data_pca,
method = "multinom", # Multinomial logistic regression
trControl = control,
tuneGrid = lr_grid,
trace = FALSE, # Suppress convergence messages
MaxNWts = 5000 # Increase max weights for high-dimensional data
)
# Display tuning results
print(lr_model)## Penalized Multinomial Regression
##
## 1497 samples
## 118 predictor
## 6 classes: 'GPD', 'GRDA', 'LPD', 'LRDA', 'Other', 'Seizure'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 1199, 1197, 1197, 1197, 1198
## Resampling results across tuning parameters:
##
## decay Accuracy Kappa
## 0.000 0.4048023 0.2828969
## 0.001 0.4054689 0.2837044
## 0.010 0.4054689 0.2837044
## 0.100 0.4081490 0.2870644
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was decay = 0.1.
# Make predictions
lr_preds <- predict(lr_model, test_data_pca)
lr_probs <- predict(lr_model, test_data_pca, type = "prob")
# Calculate confusion matrix
lr_cm <- confusionMatrix(lr_preds, y_test)
print("Logistic Regression Performance:")## [1] "Logistic Regression Performance:"
## Accuracy Kappa AccuracyLower AccuracyUpper AccuracyNull
## 4.303599e-01 3.124930e-01 3.915864e-01 4.697798e-01 2.018779e-01
## AccuracyPValue McnemarPValue
## 3.882984e-39 4.922007e-05
# Variable importance if available
if (any(grepl("varImp", methods(class = class(lr_model)[1])))) {
lr_importance <- varImp(lr_model)
plot(lr_importance, top = 20, main = "Logistic Regression: Top 20 Variables")
}# Collect performance metrics
models <- c("LightGBM", "Random Forest", "LR")
accuracy <- c(lgb_cm$overall["Accuracy"], rf_cm$overall["Accuracy"], lr_cm$overall["Accuracy"])
kappa <- c(lgb_cm$overall["Kappa"], rf_cm$overall["Kappa"], lr_cm$overall["Kappa"])
# Calculate F1 score for each class
calculate_f1 <- function(cm) {
f1_scores <- cm$byClass[, "F1"]
if (is.matrix(f1_scores)) {
return(diag(f1_scores))
} else {
return(f1_scores)
}
}
f1_lgb <- calculate_f1(lgb_cm)
f1_rf <- calculate_f1(rf_cm)
f1_lr <- calculate_f1(lr_cm)
# Mean F1 scores
mean_f1 <- c(mean(f1_lgb, na.rm = TRUE),
mean(f1_rf, na.rm = TRUE),
mean(f1_lr, na.rm = TRUE))
# Create comparison table
metrics_df <- data.frame(
Model = models,
Accuracy = accuracy,
Kappa = kappa,
Mean_F1 = mean_f1
)# Better function to plot confusion matrix
plot_cm <- function(cm, title) {
# Get the data from confusion matrix
cm_table <- as.data.frame(cm$table)
names(cm_table) <- c("Reference", "Prediction", "Freq")
# Ensure the class labels are properly ordered factors
all_levels <- levels(as.factor(c(as.character(cm_table$Reference),
as.character(cm_table$Prediction))))
cm_table$Reference <- factor(cm_table$Reference, levels = all_levels)
cm_table$Prediction <- factor(cm_table$Prediction, levels = all_levels)
# Calculate percentages by row (true class)
cm_table <- cm_table %>%
group_by(Reference) %>%
mutate(Total = sum(Freq),
Percentage = Freq / Total * 100) %>%
ungroup()
# Create the plot with cleaner formatting
ggplot(cm_table, aes(x = Prediction, y = Reference, fill = Percentage)) +
geom_tile() +
# Display both count and percentage
geom_text(aes(label = sprintf("%d\n(%.1f%%)", Freq, Percentage)),
color = "black", size = 3) +
# Use a better color gradient
scale_fill_gradient2(low = "white", high = "darkblue", mid = "skyblue",
midpoint = 50, limits = c(0, 100)) +
# Better labels
labs(title = title,
x = "Predicted Class",
y = "True Class") +
# Cleaner theme
theme_minimal() +
# Make sure axes labels are readable
theme(axis.text.x = element_text(angle = 45, hjust = 1, size = 10),
axis.text.y = element_text(size = 10),
plot.title = element_text(hjust = 0.5, size = 14))
}
# Plot confusion matrices
lgb_cm_plot <- plot_cm(lgb_cm, "LightGBM Confusion Matrix")
rf_cm_plot <- plot_cm(rf_cm, "Random Forest Confusion Matrix")
lr_cm_plot <- plot_cm(lr_cm, "LR Confusion Matrix")
# Display side by side
grid.arrange(lgb_cm_plot, nrow = 1)# Compare F1 scores by class
f1_by_class <- data.frame(
Class = factor(levels(y_test)),
LightGBM = f1_lgb,
RandomForest = f1_rf,
LR = f1_lr
)
# Format for plotting
f1_long <- reshape2::melt(f1_by_class, id.vars = "Class",
variable.name = "Model", value.name = "F1_Score")
# Plot F1 scores by class
ggplot(f1_long, aes(x = Class, y = F1_Score, fill = Model)) +
geom_bar(stat = "identity", position = position_dodge()) +
geom_text(aes(label = sprintf("%.2f", F1_Score)),
position = position_dodge(width = 0.9),
vjust = -0.5, size = 3) +
scale_fill_viridis_d() +
labs(title = "F1 Scores by Class and Model",
x = "Class", y = "F1 Score") +
theme_minimal() +
ylim(0, 1)# Function to calculate ROC curves and AUC for multiclass
calculate_multiclass_roc <- function(probs, true_class, model_name) {
# Convert labels if needed
if (is.factor(true_class)) {
true_class <- as.integer(true_class) - 1
}
# Vectors to store results
auc_values <- numeric(num_classes)
# Calculate for each class
par(mfrow = c(1, 1))
colors <- rainbow(num_classes)
for (i in 1:num_classes) {
# Binary classification: class i vs. rest
binary_true <- ifelse(true_class == i-1, 1, 0)
class_probs <- probs[, i]
# Calculate ROC
roc_obj <- roc(binary_true, class_probs)
auc_values[i] <- auc(roc_obj)
# Plot first curve, then add others
if (i == 1) {
plot(roc_obj, col = colors[i],
main = paste(model_name, "ROC Curves"),
lwd = 2)
} else {
plot(roc_obj, col = colors[i], add = TRUE, lwd = 2)
}
}
# Add legend
legend("bottomright",
legend = paste("Class", 0:(num_classes-1), "AUC =", round(auc_values, 3)),
col = colors, lwd = 2)
# Return AUC values
mean_auc <- mean(auc_values)
cat(model_name, "Mean AUC:", mean_auc, "\n")
return(list(auc = auc_values, mean_auc = mean_auc))
}
# Calculate ROC curves for each model
par(mfrow = c(1, 3))
# LightGBM
lgb_roc <- calculate_multiclass_roc(lgb_prob_matrix, y_test_0idx, "LightGBM")## LightGBM Mean AUC: 0.5211887
## Random Forest Mean AUC: 0.8066222
# Logistic Regression
lr_roc <- calculate_multiclass_roc(as.matrix(lr_probs), y_test, "Logistic Regression")## Logistic Regression Mean AUC: 0.7333402
par(mfrow = c(1, 1))
# Add AUC to metrics
metrics_df$Mean_AUC <- c(lgb_roc$mean_auc, rf_roc$mean_auc, lr_roc$mean_auc)
print(metrics_df)## Model Accuracy Kappa Mean_F1 Mean_AUC
## 1 LightGBM 0.1674491 0.001565432 0.1658995 0.5211887
## 2 Random Forest 0.5023474 0.395947657 0.5087884 0.8066222
## 3 LR 0.4303599 0.312492980 0.4298563 0.7333402
# Find best model
best_idx <- which.max(metrics_df$Accuracy)
best_model <- metrics_df$Model[best_idx]
cat("Best model by accuracy:", best_model, "\n")## Best model by accuracy: Random Forest
## Accuracy: 0.5023474
## Kappa: 0.3959477
## Mean F1: 0.5087884
## Mean AUC: 0.8066222
# Get best confusion matrix
best_cm <- switch(best_model,
"LightGBM" = lgb_cm,
"Random Forest" = rf_cm,
"LR" = lr_cm)
# Enhanced visualization of best model's confusion matrix
cm_df <- as.data.frame(best_cm$table)
names(cm_df) <- c("Reference", "Prediction", "Freq")
# Calculate percentages
cm_df <- cm_df %>%
group_by(Reference) %>%
mutate(
Total = sum(Freq),
Percentage = Freq / Total * 100,
Label = sprintf("%d\n(%.1f%%)", Freq, Percentage)
)
# Plot enhanced confusion matrix
ggplot(cm_df, aes(x = Prediction, y = Reference, fill = Percentage)) +
geom_tile() +
geom_text(aes(label = Label), color = "black") +
scale_fill_gradient2(low = "white", high = "darkblue", mid = "skyblue",
midpoint = 50, limits = c(0, 100)) +
labs(title = paste("Best Model:", best_model, "- Confusion Matrix"),
subtitle = paste("Overall Accuracy:", round(metrics_df$Accuracy[best_idx] * 100, 2), "%"),
x = "Predicted Class", y = "True Class") +
theme_minimal()# Analyze feature importance from the best model
if (best_model == "LightGBM") {
# Plot top features
lgb.plot.importance(lgb_importance, top_n = 20, measure = "Gain")
} else if (best_model == "Random Forest") {
# Plot RF importance
plot(varImp(rf_model), top = 20)
}# For all models, analyze feature importance patterns
if (best_model == "LightGBM" || best_model == "Random Forest") {
# Get importance data
if (best_model == "LightGBM") {
importance_data <- lgb_importance
importance_col <- "Gain"
} else {
importance_data <- varImp(rf_model)$importance
importance_data$Overall <- importance_data$Overall / sum(importance_data$Overall) * 100
importance_data$Variable <- rownames(importance_data)
importance_col <- "Overall"
}
# Identify patterns in top embeddings
top_n <- 50 # Number of top features to analyze
if (nrow(importance_data) >= top_n) {
top_features <- importance_data[1:top_n,]
# Extract embedding indices
extract_index <- function(feature_name) {
# Extract number after "embedding_"
idx <- as.numeric(gsub("embedding_", "", feature_name))
return(idx)
}
if (best_model == "LightGBM") {
embedding_indices <- sapply(top_features$Feature, extract_index)
top_features$Index <- embedding_indices
} else {
embedding_indices <- sapply(top_features$Variable, extract_index)
top_features$Index <- embedding_indices
}
# Plot distribution of important embedding indices
if (best_model == "LightGBM") {
hist(top_features$Index, breaks = 20,
main = "Distribution of Top Embedding Indices",
xlab = "Embedding Index", col = "skyblue")
} else {
hist(top_features$Index, breaks = 20,
main = "Distribution of Top Embedding Indices",
xlab = "Embedding Index", col = "skyblue")
}
}
}# Collect performance metrics
models <- c("Random Forest", "LightGBM")
accuracy <- c(
rf_cm$overall["Accuracy"],
lgb_cm$overall["Accuracy"]
)
kappa <- c(
rf_cm$overall["Kappa"],
lgb_cm$overall["Kappa"]
)
# Create comparison dataframe
model_performance <- data.frame(
Model = models,
Accuracy = accuracy,
Kappa = kappa
)
# Display table of performance metrics
print("Performance comparison:")## [1] "Performance comparison:"
## Model Accuracy Kappa
## 1 Random Forest 0.5023474 0.395947657
## 2 LightGBM 0.1674491 0.001565432
# Plot comparison
ggplot(model_performance, aes(x = Model, y = Accuracy, fill = Model)) +
geom_bar(stat = "identity") +
geom_text(aes(label = sprintf("%.4f", Accuracy)), vjust = -0.5) +
scale_fill_viridis_d() +
labs(title = "Model Accuracy Comparison",
y = "Accuracy") +
theme_minimal() +
ylim(0, 1)# Get predictions from best model
best_preds <- switch(best_model,
"LightGBM" = lgb_preds,
"Random Forest" = rf_preds,
"LR" = lr_preds)
# Convert both to character first to ensure matching factors
best_preds_char <- as.character(best_preds)
y_test_char <- as.character(y_test)
# Get all unique classes
all_classes <- unique(c(best_preds_char, y_test_char))
# Convert back to factors with the same levels
best_preds_factor <- factor(best_preds_char, levels = all_classes)
y_test_factor <- factor(y_test_char, levels = all_classes)
# Now find misclassified instances
misclassified <- which(best_preds_factor != y_test_factor)
# If there are any misclassified instances
if (length(misclassified) > 0) {
misclass_pairs <- data.frame(
True = y_test_factor[misclassified],
Predicted = best_preds_factor[misclassified]
)
# Count misclassification patterns
misclass_counts <- misclass_pairs %>%
group_by(True, Predicted) %>%
summarise(Count = n(), .groups = 'drop') %>%
arrange(desc(Count))
# Display top misclassification patterns
print("Top Misclassification Patterns:")
print(head(misclass_counts, 10))
# Visualize misclassification patterns if there are any
if (nrow(misclass_counts) > 0) {
# Take only up to 10 rows, but handle case where there are fewer
plot_data <- head(misclass_counts, min(10, nrow(misclass_counts)))
ggplot(plot_data, aes(x = paste(True, "→", Predicted), y = Count, fill = True)) +
geom_bar(stat = "identity") +
geom_text(aes(label = Count), vjust = -0.5) +
scale_fill_viridis_d() +
labs(title = "Top Misclassification Patterns",
subtitle = paste("Based on", best_model, "predictions"),
x = "Misclassification (True → Predicted)", y = "Count") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 45, hjust = 1))
} else {
cat("No visualization: too few misclassification patterns\n")
}
} else {
cat("No misclassified instances found!\n")
}## [1] "Top Misclassification Patterns:"
## # A tibble: 10 × 3
## True Predicted Count
## <fct> <fct> <int>
## 1 Other Seizure 30
## 2 Seizure Other 24
## 3 LPD Other 23
## 4 GRDA Other 22
## 5 LRDA Other 22
## 6 GRDA Seizure 21
## 7 GPD GRDA 21
## 8 Other GRDA 18
## 9 LRDA Seizure 17
## 10 LPD GRDA 17
# Get probability predictions from best model
best_probs <- switch(best_model,
"LightGBM" = lgb_prob_matrix,
"Random Forest" = as.matrix(rf_probs),
"LR" = as.matrix(lr_probs))
# Calculate confidence (max probability) for each prediction
confidence <- apply(best_probs, 1, max)
# Create data frame with prediction results
confidence_df <- data.frame(
True_Class = y_test_factor,
Predicted_Class = best_preds,
Confidence = confidence,
Correct = (best_preds == y_test_factor)
)
# Analyze confidence distribution for correct vs incorrect predictions
ggplot(confidence_df, aes(x = Confidence, fill = Correct)) +
geom_density(alpha = 0.7) +
scale_fill_manual(values = c("FALSE" = "firebrick", "TRUE" = "forestgreen")) +
labs(title = "Prediction Confidence Distribution",
subtitle = "Comparing correct vs. incorrect predictions",
x = "Confidence (Maximum Probability)", y = "Density") +
theme_minimal()# Analyze accuracy at different confidence thresholds
thresholds <- seq(0, 1, by = 0.05)
threshold_results <- data.frame(
Threshold = thresholds,
Accuracy = sapply(thresholds, function(t) {
filtered <- confidence_df[confidence_df$Confidence >= t, ]
if (nrow(filtered) > 0) {
return(mean(filtered$Correct))
} else {
return(NA)
}
}),
Coverage = sapply(thresholds, function(t) {
sum(confidence_df$Confidence >= t) / nrow(confidence_df)
})
)
# Plot accuracy vs threshold
ggplot(threshold_results, aes(x = Threshold)) +
geom_line(aes(y = Accuracy, color = "Accuracy"), size = 1) +
geom_line(aes(y = Coverage, color = "Coverage"), size = 1) +
scale_color_manual(values = c("Accuracy" = "forestgreen", "Coverage" = "steelblue")) +
labs(title = "Accuracy and Coverage vs. Confidence Threshold",
x = "Confidence Threshold", y = "Value") +
theme_minimal()# Find optimal threshold for high accuracy while maintaining reasonable coverage
optimal_idx <- which.max(threshold_results$Accuracy * threshold_results$Coverage)
optimal_threshold <- threshold_results$Threshold[optimal_idx]
cat("Optimal confidence threshold:", optimal_threshold, "\n")## Optimal confidence threshold: 0
## - Accuracy at this threshold: 0.5023474
## - Coverage at this threshold: 1
# Analyze per-class performance metrics
class_metrics <- data.frame(
Class = levels(y_test_factor),
Precision = best_cm$byClass[, "Precision"],
Recall = best_cm$byClass[, "Recall"],
F1 = best_cm$byClass[, "F1"],
Specificity = best_cm$byClass[, "Specificity"]
)
# Calculate class-specific confusion
class_errors <- list()
for (class_idx in 1:num_classes) {
class_label <- levels(y_test_factor)[class_idx]
# Select instances of this class
class_instances <- which(y_test_factor == class_label)
# Confusion with other classes
errors <- table(best_preds[class_instances])
class_errors[[class_label]] <- errors
}
# Print per-class metrics
print("Per-class performance metrics:")## [1] "Per-class performance metrics:"
## Class Precision Recall F1 Specificity
## Class: GPD Other 0.7763158 0.6145833 0.6860465 0.9686924
## Class: GRDA GRDA 0.4236111 0.5495495 0.4784314 0.8428030
## Class: LPD Seizure 0.6222222 0.3146067 0.4179104 0.9690909
## Class: LRDA LRDA 0.7454545 0.4226804 0.5394737 0.9741697
## Class: Other LPD 0.3701299 0.4871795 0.4206642 0.8141762
## Class: Seizure GPD 0.4545455 0.5813953 0.5102041 0.8235294
# Plot per-class metrics
metrics_long <- reshape2::melt(class_metrics, id.vars = "Class",
variable.name = "Metric", value.name = "Value")
ggplot(metrics_long, aes(x = Class, y = Value, fill = Metric)) +
geom_bar(stat = "identity", position = position_dodge()) +
geom_text(aes(label = sprintf("%.2f", Value)),
position = position_dodge(width = 0.9),
vjust = -0.5, size = 2.5) +
scale_fill_viridis_d() +
labs(title = "Performance Metrics by Class",
subtitle = paste("Using", best_model),
x = "Class", y = "Score") +
theme_minimal() +
ylim(0, 1)# Identify the most challenging class
worst_class_idx <- which.min(class_metrics$F1)
worst_class <- class_metrics$Class[worst_class_idx]
cat("Most challenging class:", worst_class, "\n")## Most challenging class: Seizure
## F1 score: 0.4179104
## Confusion pattern:
##
## GPD GRDA LPD LRDA Other Seizure
## 7 15 8 0 24 75